#!/usr/bin/env python
"""Run evaluation on a single image, using a baseline neural network"""

import math
import sys

import torch
import torch.nn.functional as F
from pytorchcv.models.wrn_cifar import wrn28_10_cifar10
from torch.distributions import Categorical
from torchvision import transforms

from nbdt.model import SoftNBDT, HardNBDT
from nbdt.thirdparty.wn import maybe_install_wordnet
from nbdt.utils import DATASET_TO_CLASSES, load_image_from_path

# Validate the command line before doing any other setup so a missing
# argument is reported immediately. An explicit exit is used instead of
# `assert`, because assertions are stripped when Python runs with -O.
if len(sys.argv) < 2:
    sys.exit("Need to pass image URL or image path as argument")

# NOTE(review): presumably fetches WordNet data when it is not yet
# available locally -- confirm against nbdt.thirdparty.wn.
maybe_install_wordnet()

# Instantiate the pretrained network and put it in evaluation mode in one
# step; `eval()` returns the module itself, so the call can be chained.
model = wrn28_10_cifar10(pretrained=True).eval()

# Read the image named by the first command-line argument (path or URL).
im = load_image_from_path(sys.argv[1])

# Per-channel statistics handed to Normalize.
# NOTE(review): presumably the CIFAR-10 training mean/std -- confirm.
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2023, 0.1994, 0.2010)

# Resize + center-crop to 32x32, convert to a tensor, then normalize.
preprocessing_steps = [
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Prepend a batch dimension so the single image forms a batch of one.
x = transform(im).unsqueeze(0)

# Forward pass only: disable autograd so no graph is built for a tensor we
# never backpropagate through. The model was already switched to eval mode
# right after it was loaded, so that call is not repeated here.
with torch.no_grad():
    probs = F.softmax(model(x), dim=1)[0]

# Confidence is 100% for a one-hot prediction and 0% for a uniform one.
# Categorical.entropy() is expressed in nats and peaks at log(num_classes),
# so it is divided by that maximum to stay in [0, 1]; subtracting the raw
# entropy from 1 would yield negative percentages for uncertain predictions.
max_entropy = math.log(probs.numel())
confidence = (1 - Categorical(probs=probs).entropy() / max_entropy) * 100.0

# Pair every probability with its class name and order the pairs from the
# most to the least probable class (sorting on the probability only).
ranked = sorted(
    zip(probs, DATASET_TO_CLASSES["CIFAR10"]),
    key=lambda pair: pair[0],
    reverse=True,
)

# Render each pair as "<class> (<percent>%)" and join them into one line.
per_class = ", ".join(
    f"{name} ({prob.item() * 100:.2f}%)" for prob, name in ranked
)

# Emit the single-line report followed by the overall confidence figure.
print("Probabilities per class: " + per_class + f"// Confidence: {confidence:.2f}%")
